import argparse
import csv
import json
import os
import random
import re
import sys
import time
from typing import Dict, List, Optional

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm

# --------------------------------------------------------------------------------------
# Model handles and persona prompts
# --------------------------------------------------------------------------------------
# Module-level handles; main() assigns the loaded model/tokenizer here and
# run_tweet_evaluation() reads them when calling verbalize().
model, tokenizer = None, None

# System prompt for each simulated persona, keyed "<age band>_<gender>".
# run_tweet_evaluation() queries the model once per entry and majority-votes
# over the parsed "Answer:" lines.
persona_prompts = {
    "18-24_female": """You are a digital content analyst who is also a woman aged 18–24. You intuitively understand what resonates with your generation—emotional authenticity, aesthetic appeal, individuality, and social relevance. You're highly fluent in trends like TikTok culture, meme literacy, mental health conversations, and empowerment messaging.

You will be shown 6 example tweets—3 that received **high likes** and 3 that received **low likes** from women in your age group. These are real engagement outcomes and serve as ground truth benchmarks. Compare the new tweet to these examples to guide your prediction.

Think step by step. First, explain your reasoning by comparing the tweet to both high- and low-performing examples and how it aligns with your generation's values. Then, conclude with:

Reason: [Your reasoning]  
Answer: [High / Low]""",

    "18-24_male": """You are a digital content analyst and a man aged 18–24. You understand the humor, boldness, and trend awareness that appeal to young men today. You're fluent in gaming references, meme culture, edginess, and influencer-driven language.

You will see 6 example tweets—3 with **high likes** and 3 with **low likes** among men in your age group. These are based on real performance data. Use them as reference to judge the new tweet.

Think step by step. First, explain your reasoning based on similarities or differences with the examples and how they connect with your generation's humor or interests. Then conclude with:

Reason: [Your reasoning]  
Answer: [High / Low]""",

    "25-34_female": """You are a digital content analyst who is also a woman aged 25–34. You understand that this demographic seeks a balance between ambition, self-care, relationships, and lifestyle goals. Aesthetic clarity, authenticity, empowerment, and intelligent humor tend to resonate.

You will be given 6 example tweets—3 that received **high likes** and 3 that received **low likes** from women aged 25–34. These are real examples. Use them to evaluate how the new tweet compares.

Think step by step. First explain your reasoning, then conclude with:

Reason: [How the tweet aligns with or diverges from the examples and your generation's values]  
Answer: [High / Low]""",

    "25-34_male": """You are a digital content analyst and a man aged 25–34. You recognize that this demographic responds to tweets that are direct, witty, aspirational, or offer insight into tech, fitness, finance, or personal growth. You value clarity and cleverness over fluff.

You will be shown 6 example tweets (3 high-performing, 3 low-performing) based on real engagement from your demographic. Compare the new tweet carefully.

Think step by step. First explain your reasoning in relation to the examples and what your generation values. Then conclude with:

Reason: [Your analysis]  
Answer: [High / Low]""",

    "35-44_female": """You are a digital content analyst and a woman aged 35–44. You understand your generation values emotional intelligence, practical wisdom, family, and health. Tweets that offer warmth, relatability, humor grounded in real life, or meaningful advice tend to perform best.

You'll be shown 6 example tweets—3 that got **high likes** and 3 that got **low likes** from women 35–44. These are ground truth signals. Use them to reason about the new tweet.

Think step by step. Start with your reasoning, then provide your final judgment:

Reason: [Comparison to examples and fit with your generation's mindset]  
Answer: [High / Low]""",

    "35-44_male": """You are a digital content analyst and a man aged 35–44. You know your generation appreciates authenticity, practical humor, and substance. Career, family, health, and finance topics—when treated with respect and clarity—tend to earn strong engagement.

You will be shown 6 grounded examples—3 high-performing and 3 low-performing tweets for your demographic. Use these to assess the next tweet.

Think step by step. Compare thoughtfully, then conclude:

Reason: [Your demographic-specific analysis]  
Answer: [High / Low]""",

    "45-54_female": """You are a digital content analyst and a woman aged 45–54. Your generation values trust, clarity, emotional depth, and lived experience. Wellness, family, community, and resilience are key themes that resonate.

You'll be shown 6 tweets with real performance outcomes—3 that received **high likes**, and 3 that received **low likes** from women 45–54. Use these to assess the next tweet.

Think step by step. Explain your evaluation, then provide:

Reason: [Your comparative reasoning and generational fit]  
Answer: [High / Low]""",

    "45-54_male": """You are a digital content analyst and a man aged 45–54. You've seen many cultural trends come and go, and you appreciate sincerity, intelligence, and practical messaging. Health, family, finance, and meaningful humor tend to engage your peers.

You will be given 6 examples—3 tweets with **high likes** and 3 with **low likes**, based on real engagement by men aged 45–54. Use them to compare the new tweet.

Think step by step. Give your reasoning first, then your prediction:

Reason: [Your analysis based on values and comparison to the examples]  
Answer: [High / Low]""",

    "55+_female": """You are a seasoned digital content analyst and a woman over 55. You see content through decades of shifting cultural values. You and your peers favor messages that are clear, emotionally resonant, and meaningful, centered around wellness, family, security, and community.

You will receive 6 grounded tweet examples—3 high-engagement and 3 low-engagement tweets from women 55+. Use these benchmarks to analyze the new tweet.

Think step by step. Begin with your reasoning, then give your conclusion:

Reason: [Your logic based on life experience, examples, and values]  
Answer: [High / Low]""",

    "55+_male": """You are a digital content analyst and a man over 55. You've witnessed the evolution of media and appreciate messaging that is sincere, wise, clear, and grounded in family, health, and security values.

You will be shown 6 real-life examples—3 tweets that performed well and 3 that did not among men over 55. Use them to guide your judgment.

Think step by step. Reflect on the tone, message, and relevance, then write:

Reason: [Your reasoning and comparison]  
Answer: [High / Low]"""
}

# Every persona prompt gets the same trailing reminder that pins the reply
# format for the tweet task.
_FORMAT_REMINDER = (
    "\n\nIMPORTANT: Provide your response in exactly two lines:\n"
    "Reason: <brief justification>\n"
    "Answer: [High / Low] (predict if the tweet will get high or low likes)\n"
    "Only output 'High' or 'Low' after 'Answer:'."
)
# The comprehension is fully built before update() runs, so the dict is never
# mutated while it is being iterated.
persona_prompts.update(
    {persona: prompt + _FORMAT_REMINDER for persona, prompt in persona_prompts.items()}
)

# --------------------------------------------------------------------------------------
# Helper functions
# --------------------------------------------------------------------------------------

def verbalize(user_prompt: str, sys_prompt: str, model, tokenizer, args) -> str:
    """Generate one sampled chat completion for a system + user prompt pair.

    Args:
        user_prompt: Content of the user turn.
        sys_prompt: Content of the system turn.
        model: Causal LM exposing ``.device`` and ``.generate``.
        tokenizer: Tokenizer exposing ``apply_chat_template`` and ``decode``.
        args: Parsed CLI namespace; accepted for call-site compatibility and
            not read here.

    Returns:
        The newly generated text (prompt tokens removed, special tokens
        skipped), stripped of surrounding whitespace.
    """
    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # Ask for a dict so we get the attention mask alongside the ids and do not
    # depend on the tokenizer's default return type.
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    # Inputs must live on the same device as the model.
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)
    prompt_len = input_ids.shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            # Passed explicitly so generate() does not have to infer the mask
            # from the pad token.
            attention_mask=attention_mask,
            max_new_tokens=1200,
            temperature=0.85,
            use_cache=True,
            do_sample=True,
            min_p=0.1,
        )

    # generate() returns prompt + continuation; keep only the continuation.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response.strip()

# --------------------------------------------------------------------------------------
# CLI
# --------------------------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        Namespace with ``start``, ``end``, ``output_dir``, ``gpu_id``,
        ``dataset_paths`` (required), ``max_examples`` and ``similarity_json``.
    """
    cli = argparse.ArgumentParser(
        description="Run static-persona evaluation for tweet engagement.",
    )
    # (flag, add_argument keyword options), registered in declaration order.
    option_table = (
        ("--start", {"type": int, "default": 0, "help": "Start index (inclusive) of the slice."}),
        ("--end", {"type": int, "default": None, "help": "End index (inclusive) of the slice."}),
        ("--output_dir", {"type": str, "default": "static_folder", "help": "Directory to write JSON results."}),
        ("--gpu_id", {"type": int, "default": 0, "help": "GPU ID to use for inference."}),
        ("--dataset_paths", {"type": str, "required": True, "help": "Comma-separated list of *.jsonl datasets to evaluate."}),
        ("--max_examples", {"type": int, "default": None, "help": "(Optional) truncate dataset to this many examples – useful for quick smoke tests."}),
        ("--similarity_json", {"type": str, "default": None, "help": "Path to JSON with pre-computed nearest neighbours."}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

# --------------------------------------------------------------------------------------
# Main evaluation logic (chunk mode)
# --------------------------------------------------------------------------------------

def main() -> None:
    """Script entry point: parse the CLI, load the model once, run the tweet task."""
    global model, tokenizer
    cli_args = parse_args()

    # Swap in "meta-llama/Llama-3.3-70B-Instruct" to evaluate LLaMA instead.
    checkpoint = "Qwen/Qwen3-32B"
    load_options = {"torch_dtype": "auto", "device_map": "auto"}

    # Stored in the module-level globals so they are loaded a single time.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint, **load_options)

    # Tweet evaluation is the only task, so it is dispatched unconditionally.
    neighbour_map = _get_similarity_map(cli_args)
    run_tweet_evaluation(cli_args, sim_map=neighbour_map)

def _get_similarity_map(args):
    """Load precomputed similarity map if provided, else return None."""
    sim_map = None
    if args.similarity_json:
        if not os.path.isfile(args.similarity_json):
            print(f"[WARNING] --similarity_json provided but file not found: {args.similarity_json}")
        else:
            with open(args.similarity_json, "r", encoding="utf-8") as _f:
                sim_map = json.load(_f)
            print(f"Loaded pre-computed similarity map from {args.similarity_json} (entries: {len(sim_map)})")
    return sim_map

def _tweet_key(rec, idx):
    """Return a unique key for a tweet record for similarity lookup. By default, use the index as string."""
    return str(idx)

def _extract_brand_and_date(text: str):
    """Very lightweight extraction of brand name and year from *text*.

    This utility tries to identify:
    1. A brand name specified as `Brand: <XYZ>` (case-insensitive, stops at whitespace).
    2. A four-digit year between 1900–2099.

    If either element is not found we fall back to the string "unknown" so that
    downstream accuracy statistics buckets remain well-formed.
    """
    # Brand
    brand_match = re.search(r"brand\s*:\s*([A-Za-z0-9_\-]+)", text, flags=re.IGNORECASE)
    brand = brand_match.group(1).lower() if brand_match else "unknown"

    # Year
    year_match = re.search(r"\b(19|20)\d{2}\b", text)
    year = year_match.group(0) if year_match else "unknown"

    return brand, year

def _extract_tweet_content(text: str) -> str:
    """Given the raw *text* field from the dataset, try to extract just the tweet body.

    The dataset prompt looks like:

        Given a tweet of pfizer posted by the account PfizerMed on 2022-10-27. Tweet : <ACTUAL TWEET TEXT>.Verbalisation of media content: ...

    We want to capture everything between "Tweet :" and the start of the verbalisation/meta section (which usually begins with ".Verbalisation" or "Verbalisation").

    If we cannot find the pattern we fall back to returning the original *text*.
    """
    # Non-greedy match between "Tweet :" (or "Tweet:") and either ".Verbalisation" or end-of-string.
    m = re.search(r"Tweet\s*:?\s*(.*?)(?:\.\s*Verbalisation|Verbalisation|$)", text, flags=re.IGNORECASE | re.DOTALL)
    if m:
        return m.group(1).strip()
    return text.strip()

# Ground-truth responses are treated as "high" iff they mention "high likes".
_HIGH_LIKES_RE = re.compile(r"high likes", flags=re.IGNORECASE)

# Matches "Answer: High", "Answer: [Low]" and "**Answer:** high". The lookahead
# rejects an echoed template such as "Answer: [High / Low]".
_ANSWER_RE = re.compile(
    r"Answer:[\s*\[(]*(high|low)(?!\s*/\s*(?:high|low))",
    flags=re.IGNORECASE,
)


def _record_text(rec: Dict, key: str) -> str:
    """Return ``rec[key]`` as a string, mapping a missing or null field to ''."""
    value = rec.get(key)
    return "" if value is None else str(value)


def _load_jsonl_records(path: str, max_examples: Optional[int]):
    """Read dict records from a JSONL file; return ``(records, skipped_count)``."""
    records: List[Dict] = []
    skipped = 0
    with open(path, "r", encoding="utf-8") as f_in:
        for line_idx, line in enumerate(f_in):
            # The cap counts physical lines, including ones that get skipped.
            if max_examples and line_idx >= max_examples:
                break
            try:
                parsed = json.loads(line)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                skipped += 1
                continue
            if isinstance(parsed, dict):
                records.append(parsed)
            else:
                # Valid JSON but not an object; it cannot carry prompt/response.
                skipped += 1
    return records, skipped


def _sample_mixed_neighbors(idx: int, is_high: List[bool], k: int = 5,
                            max_attempts: int = 10) -> List[int]:
    """Pick up to *k* random indices other than *idx*, mixing high and low labels."""
    pool = [i for i in range(len(is_high)) if i != idx]
    k = min(k, len(pool))
    for _ in range(max_attempts):
        picked = random.sample(pool, k=k)
        flags = [is_high[i] for i in picked]
        if any(flags) and not all(flags):
            return picked  # at least one high and one low

    # Fallback: force one example of each class where the pool has one, then
    # top up with random others.
    forced: List[int] = []
    highs = [i for i in pool if is_high[i]]
    lows = [i for i in pool if not is_high[i]]
    if highs:
        forced.append(random.choice(highs))
    if lows:
        forced.append(random.choice(lows))
    taken = set(forced)
    rest = [i for i in pool if i not in taken]
    return forced + random.sample(rest, k=max(0, min(k - len(forced), len(rest))))


def _parse_answer_label(resp_text: str) -> Optional[str]:
    """Return 'high'/'low' from the last ``Answer:`` line in *resp_text*, else None."""
    # The last match wins so that an "Answer:" inside a reasoning trace does
    # not override the final verdict.
    found = _ANSWER_RE.findall(resp_text)
    return found[-1].lower() if found else None


def _write_json_atomic(path: str, payload) -> None:
    """Write *payload* as JSON via a temp file and rename.

    A crash mid-write then leaves the previous file intact.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as fh:
        json.dump(payload, fh, indent=2)
    os.replace(tmp_path, path)


def run_tweet_evaluation(args, sim_map=None):
    """Evaluate every persona on JSONL datasets of {"prompt":..., "response":...} lines.

    For each dataset in ``args.dataset_paths`` (comma-separated) this:
      1. loads the records (capped by ``args.max_examples``) and applies the
         inclusive ``args.start``/``args.end`` slice;
      2. for each record, samples up to five other records from the same slice
         as example tweets, queries the model once per persona in
         ``persona_prompts``, and majority-votes the parsed answers (a tie, or
         no parseable answer at all, yields "low");
      3. rewrites ``tweet_results_<dataset>_<start>_<end>.json`` in
         ``args.output_dir`` after every record, then prints overall,
         per-brand and per-year accuracy.

    Args:
        args: Parsed CLI namespace (see ``parse_args``).
        sim_map: Accepted for interface compatibility but currently NOT used;
            neighbours are always drawn at random.

    Note:
        ``neighbor_ids`` in the output index into the sliced record list, not
        the original file.
    """
    global model, tokenizer
    # Resolve dataset paths
    dset_paths = [p.strip() for p in args.dataset_paths.split(",") if p.strip()]

    overall_out_dir = args.output_dir or "tweet_static_results"
    os.makedirs(overall_out_dir, exist_ok=True)

    for dpath in dset_paths:
        dataset_name = os.path.basename(dpath)
        print(f"\n[INFO] Processing dataset: {dataset_name}")

        records, skipped = _load_jsonl_records(dpath, args.max_examples)
        if skipped:
            print(f"[WARNING] Skipped {skipped} malformed line(s) in {dataset_name}")

        # --- Apply slicing if --start/--end are provided (both inclusive) ---
        start_arg = getattr(args, "start", None)
        end_arg = getattr(args, "end", None)
        slice_start = max(0, start_arg) if start_arg is not None else 0
        slice_end = end_arg if end_arg is not None else len(records) - 1
        slice_end = min(slice_end, len(records) - 1)
        if slice_start > 0 or slice_end < len(records) - 1:
            records = records[slice_start : slice_end + 1]
            print(f"[INFO] Processing slice {slice_start}-{slice_end} (n={len(records)}) of {dataset_name}")
        else:
            print(f"[INFO] Processing full dataset {dataset_name} (n={len(records)})")

        slice_suffix = f"_{slice_start}_{slice_end}"
        out_path = os.path.join(overall_out_dir, f"tweet_results_{dataset_name}{slice_suffix}.json")

        correct = 0
        brand_stats = {}
        time_stats = {}
        all_results = []

        # Label every record once instead of re-running the regex per neighbour.
        is_high = [bool(_HIGH_LIKES_RE.search(_record_text(r, "response"))) for r in records]

        for idx, rec in enumerate(tqdm(records, desc=dataset_name)):
            prompt_text_raw = _record_text(rec, "prompt")
            prompt_text = _extract_tweet_content(prompt_text_raw)
            gt_label = "high" if is_high[idx] else "low"

            # --- FEW-SHOT EXAMPLES: random neighbours with a high/low mix ---
            neighbor_ids = _sample_mixed_neighbors(idx, is_high)
            log_msg = f"[RANDOM_MIXED] Used random mixed neighbors for idx {idx}: {neighbor_ids}"

            # NOTE(review): examples are shown WITHOUT their high/low labels and
            # there are at most five of them, while the persona system prompts
            # promise six labelled examples. Confirm whether this is intended.
            examples_text = "\n---\n".join(
                _extract_tweet_content(_record_text(records[sid], "prompt"))
                for sid in neighbor_ids
            )

            # The user prompt is identical for every persona, so build it once.
            user_prompt = (
                "Below are five example tweets. After these, you'll see a new tweet. "
                "Predict whether it will receive high or low likes on Twitter. "
                "Return two lines exactly:\nReason: <brief>\nAnswer: [High / Low]\n"
                "Examples:\n" + examples_text +
                "\n---\n" + prompt_text
            )

            persona_outputs = {}
            high_count = 0
            low_count = 0
            for persona_name, sys_prompt in persona_prompts.items():
                resp_text = verbalize(user_prompt, sys_prompt, model, tokenizer, args)
                label = _parse_answer_label(resp_text)
                persona_outputs[persona_name] = {"response": resp_text, "label": label}
                # Personas with no parseable answer abstain from the vote.
                if label == "high":
                    high_count += 1
                elif label == "low":
                    low_count += 1

            # Majority vote; ties and all-abstain default to "low".
            pred_label = "high" if high_count > low_count else "low"

            # Accuracy bookkeeping
            is_correct = pred_label == gt_label
            if is_correct:
                correct += 1

            brand, year = _extract_brand_and_date(prompt_text_raw)
            for bucket, key in ((brand_stats, brand), (time_stats, year)):
                stats = bucket.setdefault(key, {"correct": 0, "total": 0})
                stats["total"] += 1
                if is_correct:
                    stats["correct"] += 1

            all_results.append({
                "prompt": prompt_text,
                "ground_truth": gt_label,
                "persona_predictions": persona_outputs,
                "predicted_label": pred_label,
                "neighbor_ids": neighbor_ids,
                "neighbor_log": log_msg,
            })

            # Incremental save after every example to avoid data loss. This is
            # best-effort: a failed save must not abort a long inference run.
            try:
                _write_json_atomic(out_path, all_results)
            except (OSError, TypeError, ValueError) as _e:
                print(f"[WARNING] Incremental save failed: {_e}")

        # Final save of per-dataset results (errors propagate here on purpose).
        _write_json_atomic(out_path, all_results)

        # Report accuracies
        total = len(records)
        overall_acc = correct / total if total else 0.0
        print(f"Overall accuracy for {dataset_name}: {overall_acc:.3f} ({correct}/{total})")

        print("\nAccuracy by brand:")
        for b, st in sorted(brand_stats.items(), key=lambda x: x[0]):
            acc = st["correct"] / st["total"] if st["total"] else 0.0
            print(f"  {b}: {acc:.3f} ({st['correct']}/{st['total']})")

        print("\nAccuracy by year:")
        for y, st in sorted(time_stats.items(), key=lambda x: x[0]):
            acc = st["correct"] / st["total"] if st["total"] else 0.0
            print(f"  {y}: {acc:.3f} ({st['correct']}/{st['total']})")

    print("\n[INFO] Tweet evaluation complete.")

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


            

            

        
        
        

        
